
LTX2 Performance tuning#385

Open
prishajain1 wants to merge 1 commit into main from prisha/ltx2_opt

Conversation


@prishajain1 prishajain1 commented Apr 19, 2026

This PR adds features that improve performance of the LTX-2 model, along with a fix for the broken LTX-2 Upsampler in main.

Performance enhancement features:

Experiments performed on v7x-8

  • Sharding fix in NNXSImpleFeedForward: activations are sharded along the sequence dimension, so each device holds a subset of tokens but the full feature channels. Because the input had replicated features while the weights expected features sharded on the same physical axis (context), XLA was forced to insert an all-gather on the sequence dimension to resolve the layout conflict, wasting a large fraction of step time. With our fix:

    • Overall % of wasted time in all-gathers went from 52.56% to 38.07%
    • Generation time per video dropped from 20s to 16.7s
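    The layout fix above can be sketched in plain JAX. This is a minimal illustration with made-up shapes and a generic matmul, not the actual NNXSImpleFeedForward code; the mesh axis name "context" follows the description above. The key point is that the contracted feature dimension stays replicated while the sequence dimension stays sharded, so no all-gather is needed:

    ```python
    import jax
    import jax.numpy as jnp
    from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

    devices = jax.devices()
    mesh = Mesh(devices, axis_names=("context",))

    # Activations: (batch, seq, features) -- shard seq across "context".
    x = jnp.ones((1, 8 * len(devices), 64))
    x = jax.device_put(x, NamedSharding(mesh, P(None, "context", None)))

    # Weights replicated: no sharding on the contracted feature dim, so each
    # device can run its local matmul without gathering the sequence.
    w = jnp.ones((64, 256))
    w = jax.device_put(w, NamedSharding(mesh, P(None, None)))

    @jax.jit
    def feed_forward(x, w):
        y = x @ w
        # Pin the output layout: still sharded on the sequence axis only.
        return jax.lax.with_sharding_constraint(
            y, NamedSharding(mesh, P(None, "context", None)))

    y = feed_forward(x, w)
    print(y.shape)
    ```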
  • QKV projection sharding fix (Ironwood-specific): profiling showed the input being all-gathered along the sequence dimension at the QKV projection step. Because the weights were sharded on the dimension that is summed over (features), a single device could not complete the matrix multiplication with only its local shard, so XLA automatically inserted an all-gather to replicate the sequence dimension across all devices before the multiplication. We changed the weight sharding in attention_ltx2.py to remove sharding on the input feature dimension.

    • Overall % of wasted time in all-gathers went from 38.07% to 19.39%
    • Generation time per video dropped from 16.7s to 13.84s
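    The sharding change can be illustrated with PartitionSpecs. The axis names below ("context", "tensor") are illustrative assumptions, not the actual specs in attention_ltx2.py:

    ```python
    from jax.sharding import PartitionSpec as P

    # Before: the QKV kernel's input-feature dim (which the matmul contracts
    # over) was sharded on the same physical axis as the sequence, forcing
    # XLA to all-gather the sequence-sharded activations first.
    qkv_kernel_spec_before = P("context", None)

    # After: the contracted dim is replicated; shard only the output dim,
    # so each device's local matmul needs no gather.
    qkv_kernel_spec_after = P(None, "tensor")
    ```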
  • Batching in text encoder: with Classifier-Free Guidance (CFG) enabled, the text encoder previously ran twice, once for the positive prompt and once for the negative prompt. We now concatenate the two prompts into a single batch and run a single text-encoder pass.

    • Text encoder time reduced from 3.54s to 3.06s
    • Generation time per video dropped from 13.84s to 13.38s
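    A minimal sketch of the CFG batching idea, with a stand-in `text_encoder` callable and made-up shapes (the real pipeline's encoder and token shapes differ):

    ```python
    import jax.numpy as jnp

    def encode_prompts_batched(text_encoder, pos_tokens, neg_tokens):
        # One forward pass over a batch of 2 instead of two passes of batch 1.
        tokens = jnp.concatenate([pos_tokens, neg_tokens], axis=0)
        embeds = text_encoder(tokens)
        pos_embeds, neg_embeds = jnp.split(embeds, 2, axis=0)
        return pos_embeds, neg_embeds

    # Toy encoder for illustration: a simple scaling.
    pos, neg = jnp.ones((1, 77, 8)), jnp.zeros((1, 77, 8))
    p, n = encode_prompts_batched(lambda t: t * 2.0, pos, neg)
    print(p.shape, n.shape)  # (1, 77, 8) (1, 77, 8)
    ```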
  • JITting Diffusion Loop: The current implementation uses a Python for loop to iterate over diffusion timesteps. This created a "Python Dispatch Wall," resulting in some idle time between consecutive forward passes while the TPU waited for the host CPU to dispatch the next step. We refactored the entire denoising loop to use nnx.scan.

    • The total diffusion time across 40 steps dropped from 7.84s to 7.28s
    • Generation time per video dropped from 13.38s to 12.5s
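The scan refactor above can be sketched as follows. The PR uses flax's nnx.scan; this minimal version shows the same idea with jax.lax.scan and a stand-in denoiser, so all 40 steps run inside one compiled program instead of being dispatched step-by-step from Python:

```python
import jax
import jax.numpy as jnp

def denoise_step(latents, t):
    # Stand-in for the transformer forward pass + scheduler update.
    noise_pred = 0.1 * latents
    return latents - t * noise_pred, None

@jax.jit
def run_diffusion(latents, timesteps):
    # The whole loop is traced once; the TPU never waits on Python dispatch.
    latents, _ = jax.lax.scan(denoise_step, latents, timesteps)
    return latents

latents = jnp.ones((1, 16, 16))
timesteps = jnp.linspace(1.0, 0.0, 40)
out = run_diffusion(latents, timesteps)
print(out.shape)  # (1, 16, 16)
```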

LTX2 Upsampler fix:

  • The current LTX2 Upsampler pipeline raises a ValueError: blur_down is the submodule name in the PyTorch state dict from the Hugging Face checkpoint, but the MaxDiffusion Flax implementation names that layer blur. Because the weight loader didn't rename it, nnx.update tried to update a non-existent blur_down attribute. The fix renames the checkpoint key accordingly.
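A minimal sketch of the rename, assuming a flat state dict keyed by dotted paths (key names other than blur_down/blur are illustrative):

```python
def rename_upsampler_keys(state_dict):
    # Map the checkpoint's `blur_down` submodule name onto the Flax
    # module's `blur` attribute before calling nnx.update.
    return {key.replace("blur_down", "blur"): value
            for key, value in state_dict.items()}

ckpt = {"upsampler.blur_down.weight": [1.0], "upsampler.conv.weight": [2.0]}
print(rename_upsampler_keys(ckpt))
# {'upsampler.blur.weight': [1.0], 'upsampler.conv.weight': [2.0]}
```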

Results

v7x-8

| Version | Execution Time | Status |
| --- | --- | --- |
| Current Main | 20.01s | Video Link |
| After Fix | 12.50s | Video Link |

We also tested WAN I2V pipelines to ensure no regressions are caused there. No quality regression or increased latency was observed.

@prishajain1 prishajain1 requested a review from entrpn as a code owner April 19, 2026 07:11
@prishajain1 prishajain1 marked this pull request as draft April 19, 2026 09:00
@prishajain1 prishajain1 marked this pull request as ready for review April 20, 2026 12:30
@prishajain1 prishajain1 force-pushed the prisha/ltx2_opt branch 2 times, most recently from efbbdc8 to 79fd839, on April 20, 2026 17:11
